package LiuYun

import spinal.core._
import spinal.lib._

/** Minimal prediction payload that a BTB entry mixing in this trait must provide. */
trait BTB {
    /** Direction stored for the branch. */
    val taken: Bool

    /** Destination address stored for the branch. */
    val target: UInt
}

/** Sizing constants shared by the BTB variants declared in this file. */
object BTBConfig {
    /** Number of entries held by each BTB table. */
    def Count: Int = 64

    /** Width, in bits, of the tag stored with every BTB entry. */
    def Tagwidth: Int = 12

    //def Savewidth :Int = 2 + log2Up(BTBConfig.Count)  //hit ## taken ## index

    /** Width, in bits, of the `ghr` field on the lookup and update ports. */
    def Ghrwidth: Int = 12

    /** Number of slots in the return-address stack. */
    def RasCount: Int = 4
}

/**
  * Tag-match result that travels with a lookup and is handed back on the update port.
  *
  * @param count number of BTB entries; `index` is sized to `log2Up(count)` bits
  */
class BTBData(val count: Int) extends Bundle with IMasterSlave {
    // Field order is kept as-is: it determines the bundle's bit layout.
    val hit = Bool()
    val index = UInt(log2Up(count) bits)

    /** Seen from the master, both fields are outputs. */
    override def asMaster(): Unit = {
        out(hit)
        out(index)
    }
}

/**
  * Port supplying the address that the BTBs with a return-address stack store on a push.
  * The width comes from `LISA.GPR`, which is defined outside this file.
  */
class Ras extends Bundle with IMasterSlave {
    //val push :Bool = Bool()
    val target: UInt = LISA.GPR

    /** The master drives `target`. */
    override def asMaster(): Unit = {
        out(target)
    }
}

/**
  * Lookup port of a BTB.
  *
  * Directions are declared from the point of view of the BTB itself (the components
  * in this file instantiate the port with `master(...)`): the BTB receives
  * `pc`, `ghr` and `valid`, and answers with `taken`, `target` and `data`.
  */
class BTBLookup extends Bundle with IMasterSlave {
    // Request side
    val pc: UInt = LISA.GPR
    val ghr = Bits(BTBConfig.Ghrwidth bits)
    val valid: Bool = Bool()

    // Response side
    val taken: Bool = Bool()
    val target: UInt = LISA.GPR
    val data: BTBData = new BTBData(BTBConfig.Count)

    override def asMaster(): Unit = {
        master(data)
        out(taken, target)
        in(pc, ghr, valid)
    }
}

/**
  * Update port of a BTB.
  *
  * Every field, including the nested `data` bundle, is driven by the master of this
  * port; the BTB components in this file instantiate it with `slave(...)`.
  */
class BTBUpdate extends Bundle with IMasterSlave {
    val valid: Bool = Bool()
    val pc: UInt = LISA.GPR
    val ghr = Bits(BTBConfig.Ghrwidth bits)
    val taken: Bool = Bool()
    val target: UInt = LISA.GPR
    // Branch kind: 00 for bl, 01 for jump, 10 for cond jump, 11 for link
    val brutype: Bits = Bits(2 bits)
    // Hit flag and entry index, as reported in the lookup port's `data` field
    val data: BTBData = new BTBData(BTBConfig.Count)

    override def asMaster(): Unit = {
        master(data)
        out(valid, pc, ghr, taken, target, brutype)
    }
}

/**
  * Placeholder BTB with the same ports as the real implementations.
  * It never reports a hit and never predicts taken; `update` and `ras` are ignored.
  */
class NoneBTB extends Component {
    val io = new Bundle {
        val lookup = master(new BTBLookup)
        val update = slave(new BTBUpdate)
        val ras = slave(new Ras)
    }

    // Constant "no prediction" response on the lookup port.
    io.lookup.taken := False
    io.lookup.target := U(0)
    io.lookup.data.hit := False
    io.lookup.data.index := U(0)
}

//---------------------------------------------------------------------------------------------------------------------

/**
  * Entry of [[SimpleBTB]]: stored direction, stored target and the tag to match against.
  *
  * @param w kept for callers; not used inside the bundle (the tag width is taken
  *          from `BTBConfig.Tagwidth`)
  */
class SimpleBTBentry(val w: Int) extends Bundle with BTB {
    // Field order is kept as-is: it determines the bundle's bit layout.
    val taken: Bool = Bool()
    val target: UInt = LISA.GPR
    val tag: UInt = UInt(BTBConfig.Tagwidth bits)
}

/**
  * Fully associative BTB that remembers the last reported direction and target.
  *
  * Lookup compares the pc tag against every valid entry and returns the stored
  * `taken`/`target` of the matching entry. Update either rewrites the entry named
  * by `io.update.data` (when it was a hit) or allocates the slot pointed to by
  * `record`, which then advances by one. The `ras` port is not used.
  */
class SimpleBTB extends Component {
    val io = new Bundle {
        val lookup = master(new BTBLookup)
        val update = slave(new BTBUpdate)
        val ras = slave(new Ras)
    }

    // Next slot to allocate on a miss; advances by one per allocation.
    val record = Reg(UInt(log2Up(BTBConfig.Count) bits)) init(0)

    val table = Reg(Vec(new SimpleBTBentry(BTBConfig.Tagwidth), BTBConfig.Count))
    val entryvalid = Reg(Bits(BTBConfig.Count bits)) init(0)

    // The tag is pc bits [Tagwidth + 1 : 2].
    val lookuptag = io.lookup.pc.asBits(BTBConfig.Tagwidth + 1 downto 2).asUInt
    val updatetag = io.update.pc.asBits(BTBConfig.Tagwidth + 1 downto 2).asUInt

    // One match bit per entry: the entry is valid and its tag equals the lookup tag.
    val onehotvec = Vec(Bool, BTBConfig.Count)
    for (slot <- 0 until BTBConfig.Count) {
        onehotvec(slot) := (table(slot).tag === lookuptag) && entryvalid(slot)
    }

    val hit: Bool = onehotvec.reduce(_ | _)
    val index: UInt = OHToUInt(onehotvec)

    // Lookup response
    io.lookup.taken := table(index).taken
    io.lookup.target := table(index).target
    io.lookup.data.hit := hit
    io.lookup.data.index := index

    // Update path
    when(io.update.valid) {
        when(io.update.data.hit) {
            // Known entry: refresh direction and target in place.
            val hitEntry = table(io.update.data.index)
            hitEntry.taken := io.update.taken
            hitEntry.target := io.update.target
        }.otherwise {
            // Unknown branch: fill the slot at `record` and move the pointer on.
            val freshEntry = table(record)
            freshEntry.tag := updatetag
            freshEntry.taken := io.update.taken
            freshEntry.target := io.update.target
            entryvalid(record) := True
            record := record + U(1)
        }
    }
}

//---------------------------------------------------------------------------------------------------------------------

/**
  * Entry of [[NormalBTB]]: a 2-bit counter, the stored target and the tag.
  *
  * @param w kept for callers; not used inside the bundle (the tag width is taken
  *          from `BTBConfig.Tagwidth`)
  */
class NormalBTBentry(val w: Int) extends Bundle {
    // Field order is kept as-is: it determines the bundle's bit layout.
    val sc: UInt = UInt(2 bits) // counter; NormalBTB predicts taken from bit 1
    val target: UInt = LISA.GPR
    val tag: UInt = UInt(BTBConfig.Tagwidth bits)
}

/**
  * Fully associative BTB with a 2-bit saturating counter per entry.
  *
  * Lookup: the pc tag is compared against every valid entry; on a hit the
  * prediction is bit 1 of the entry's counter, on a miss it is "not taken".
  *
  * Update (only when `io.update.valid`):
  *  - `brutype(1)` set: on a hit the counter moves one step towards the reported
  *    direction (saturating); on a miss a new entry is allocated with counter 2.
  *  - otherwise, `brutype(0)` set: on a hit the counter is forced to 3; on a miss a
  *    new entry is allocated with counter 3.
  * New entries go to the slot pointed to by `record`, which then advances by one.
  * The `ras` port is not used.
  */
class NormalBTB extends Component{
    val io = new Bundle{
        val lookup = master(new BTBLookup)
        val update = slave(new BTBUpdate)
        val ras = slave(new Ras)
    }
    val record = Reg(UInt(log2Up(BTBConfig.Count) bits)) init(0) 
    //for insert and replace BTBentry

    val table = Reg(Vec(new NormalBTBentry(BTBConfig.Tagwidth), BTBConfig.Count))
    val entryvalid = Reg(Bits(BTBConfig.Count bits)) init(0)
    // The tag is pc bits [Tagwidth + 1 : 2].
    val lookuptag = io.lookup.pc.asBits(BTBConfig.Tagwidth + 1 downto 2).asUInt //TODO
    val updatetag = io.update.pc.asBits(BTBConfig.Tagwidth + 1 downto 2).asUInt //TODO
    // One match bit per entry: the entry is valid and its tag equals the lookup tag.
    val onehotvec = Vec(Bool,BTBConfig.Count)
    for(i <- 0 until BTBConfig.Count){
        when(table(i).tag === lookuptag && entryvalid(i)){
            onehotvec(i) := True
        }.otherwise{
            onehotvec(i) := False
        }
    }

    val hit :Bool= onehotvec.reduce(_|_)
    val index :UInt= OHToUInt(onehotvec)
    io.lookup.data.hit := hit
    io.lookup.data.index := index
    // Predict taken only on a hit whose counter is in the upper half (2 or 3).
    io.lookup.taken := Mux(hit, table(index).sc(1), False)
    io.lookup.target :=  table(index).target

    when(io.update.valid && io.update.brutype(1)){ //cond jump
        when(io.update.data.hit){
            // Saturating step towards the reported direction.
            val old_sc = table(io.update.data.index).sc
            val new_sc = Mux(io.update.taken, old_sc +| U(1), old_sc -| U(1))
            table(io.update.data.index).sc := new_sc
            table(io.update.data.index).target := io.update.target
        }.otherwise{
            table(record).tag := updatetag
            table(record).sc := U(2) //initiate to weak taken
            table(record).target := io.update.target
            record := record + U(1)
            entryvalid(record) := True
        }
    }.elsewhen(io.update.valid && io.update.brutype(0)){ //jump
        when(io.update.data.hit){
            // Force the counter to 3. This must be an assignment (`:=`); a `===`
            // here only builds a comparison whose result is thrown away.
            table(io.update.data.index).sc := U(3)
            table(io.update.data.index).target := io.update.target
        }.otherwise{
            table(record).tag := updatetag
            table(record).sc := U(3)
            table(record).target := io.update.target
            record := record + U(1)
            entryvalid(record) := True
        }
    }
}

//---------------------------------------------------------------------------------------------------------------------

/**
  * BTB entry carrying a target, a tag and three branch-kind flags.
  * The BTBs that use this entry set the flags from `brutype` on update:
  * `cond` for 2, `ras` for 3, `bl` for 0.
  *
  * @param w kept for callers; not used inside the bundle (the tag width is taken
  *          from `BTBConfig.Tagwidth`)
  */
class ComplexBTBentry(val w: Int) extends Bundle {
    // Field order is kept as-is: it determines the bundle's bit layout.
    val target: UInt = LISA.GPR
    val tag: UInt = UInt(BTBConfig.Tagwidth bits)
    val cond: Bool = Bool() // direction comes from the predictor instead of "always taken"
    val ras: Bool = Bool()  // target comes from the return-address stack; lookup pops it
    val bl: Bool = Bool()   // lookup pushes onto the return-address stack
}

/**
  * Helpers for a table of 2-bit counters indexed by `THR.xor(pc, ghr, Width)`.
  * `THR` is defined outside this file; presumably it folds the pc with the global
  * history into a `Width`-bit index — confirm against its definition.
  */
object GShare {
    def Count: Int = 1024
    def Width: Int = log2Up(Count)
    def I_low: Int = 2
    def I_high: Int = log2Up(Count) + 1

    /**
      * Reads the prediction for `pc`/`ghr`.
      *
      * @return bit 1 of the selected counter
      */
    def lookup(table: Vec[GShareEntry], pc: UInt, ghr: Bits): Bool = {
        val slot = THR.xor(pc, ghr, Width)
        table(slot).sc(1)
    }

    /**
      * Moves the selected counter one saturating step up when `taken`, down otherwise.
      * Must be called from a hardware context; the caller supplies any enable condition.
      */
    def update(table: Vec[GShareEntry], pc: UInt, ghr: Bits, taken: Bool): Unit = {
        val slot = THR.xor(pc, ghr, Width)
        val entry = table(slot)
        entry.sc := Mux(taken, entry.sc +| U(1), entry.sc -| U(1))
    }
}

/** One slot of the GShare table: a single 2-bit counter. */
class GShareEntry extends Bundle {
    //val valid:Bool = Bool()
    //val tag:UInt = UInt(GShare.Tagwidth bits)
    val sc: UInt = UInt(2 bits)

    /** Drives the counter with 0; used to build the reset value of the table registers. */
    def setEmpty: Unit = {
        //this.valid := False
        sc := U(0)
    }
}

/**
  * Fully associative BTB combined with a return-address stack (RAS) and a GShare
  * direction predictor.
  *
  * Lookup: the pc tag is compared against every valid entry. On a hit the
  * prediction is "taken", except for entries flagged `cond`, which use the GShare
  * result. Entries flagged `ras` take their target from the top of the RAS; all
  * others use the stored target.
  *
  * Update (only when `io.update.valid`): the GShare counter is trained, and the
  * BTB entry is either refreshed in place (hit) or allocated at `record` (miss).
  */
class ComplexBTB extends Component{
    val io = new Bundle{
        val lookup = master(new BTBLookup)
        val update = slave(new BTBUpdate)
        val ras = slave(new Ras)
    }
    val record = Reg(UInt(log2Up(BTBConfig.Count) bits)) init(0) 
    //for insert and replace BTBentry

    val table = Reg(Vec(new ComplexBTBentry(BTBConfig.Tagwidth), BTBConfig.Count))
    val entryvalid = Reg(Bits(BTBConfig.Count bits)) init(0)
    // The tag is pc bits [Tagwidth + 1 : 2].
    val lookuptag = io.lookup.pc.asBits(BTBConfig.Tagwidth + 1 downto 2).asUInt
    val updatetag = io.update.pc.asBits(BTBConfig.Tagwidth + 1 downto 2).asUInt
    // One match bit per entry: the entry is valid and its tag equals the lookup tag.
    val onehotvec = Vec(Bool,BTBConfig.Count)
    for(i <- 0 until BTBConfig.Count){
        when(table(i).tag === lookuptag && entryvalid(i)){
            onehotvec(i) := True
        }.otherwise{
            onehotvec(i) := False
        }
    }

    // Return-address stack. `top` points at the next free slot and wraps around.
    val ras = new Area{
        val stack = Vec(Reg(LISA.GPR), BTBConfig.RasCount) 
        val top = Reg(UInt(log2Up(BTBConfig.RasCount) bits)) init(0)
        // Most recently pushed address: the slot just below `top`, wrapping at 0.
        val value :UInt = Mux(top === U(0), stack(BTBConfig.RasCount - 1), stack(top - U(1)))
        val push = Bool()
        val pop = Bool()
        // NOTE(review): push/pop are driven from the flags of table(index) and are
        // qualified by io.lookup.valid only, not by `hit` — confirm this is intended.
        when(push && io.lookup.valid){
            stack(top) := io.ras.target
            top := top + U(1)
        }

        when(pop && io.lookup.valid){
            top := top - U(1)
        }
    }

    // GShare direction predictor; all counters reset to 0.
    val predictor = new Area{
        val Empty = new GShareEntry
        Empty.setEmpty
        val table = Vec(Reg(new GShareEntry) init(Empty), GShare.Count)
        // val lookupindex = THR.xor(io.lookup.pc, io.lookup.ghr, GShare.Width) //for debug
        // val updateindex = THR.xor(io.update.pc, io.update.ghr, GShare.Width) //for debug
        val taken :Bool = GShare.lookup(table, io.lookup.pc, io.lookup.ghr)
        when(io.update.valid){
            GShare.update(table, io.update.pc, io.update.ghr, io.update.taken)
        }
    }

/* lookup logic */
    val hit :Bool = onehotvec.reduce(_|_)
    val index :UInt = OHToUInt(onehotvec)
    val is_bl  :Bool = table(index).bl
    val is_cond :Bool = table(index).cond
    val is_ras  :Bool = table(index).ras
    val target  :UInt = table(index).target
    // Only `cond` entries consult the predictor; every other hit is predicted taken.
    val hitTaken :Bool = Mux(is_cond, predictor.taken, True)
    io.lookup.data.hit := hit
    io.lookup.data.index := index
    io.lookup.taken := Mux(hit, hitTaken, False)
    io.lookup.target :=  Mux(is_ras, ras.value, target)
    ras.pop := is_ras
    ras.push := is_bl


/* update logic */
    when(io.update.valid){
        when(io.update.data.hit){
            // Refresh the entry that produced the hit. All three kind flags are
            // written to that same entry; the replacement slot at `record` is
            // left untouched.
            table(io.update.data.index).ras := io.update.brutype === B(3)
            table(io.update.data.index).cond := io.update.brutype === B(2)
            table(io.update.data.index).bl := io.update.brutype === B(0)
            table(io.update.data.index).target := io.update.target //valid only when brutype === B(1)
        }.otherwise{
            // Allocate the slot at `record` and advance the replacement pointer.
            table(record).tag := updatetag
            table(record).ras := io.update.brutype === B(3)
            table(record).cond := io.update.brutype === B(2)
            table(record).bl := io.update.brutype === B(0)
            table(record).target := io.update.target //valid only when brutype === B(1)
            entryvalid(record) := True
            record := record + U(1)

        }
    }

}

//---------------------------------------------------------------------------------------------------------------------

/**
  * Two Level Adaptive Training Branch Predictor helpers.
  *
  * Each table entry holds a local history `ph` and one 2-bit counter per history
  * pattern. The entry is selected with `THR.xor(pc, Width)` (`THR` is defined
  * outside this file); the `ghr` arguments are accepted but not used.
  */
object TLAP {
    def Count: Int = 64
    def history_len: Int = 4
    def sc_count: Int = 1 << history_len
    def Width: Int = log2Up(Count)
    def I_low: Int = 2
    def I_high: Int = log2Up(Count) + 1

    /**
      * Reads the prediction for `pc`.
      *
      * @return bit 1 of the counter selected by the entry's current history
      */
    def lookup(table: Vec[TLAPEntry], pc: UInt, ghr: Bits): Bool = {
        val slot = table(THR.xor(pc, Width))
        slot.sc(slot.ph)(1)
    }

    /**
      * Trains the entry selected by `pc`: the counter addressed by the current
      * history takes one saturating step towards `taken`, then `taken` is shifted
      * into the history. Must be called from a hardware context; the caller
      * supplies any enable condition.
      */
    def update(table: Vec[TLAPEntry], pc: UInt, ghr: Bits, taken: Bool): Unit = {
        val slot = table(THR.xor(pc, Width))
        val history = slot.ph

        val current = slot.sc(history)
        slot.sc(history) := Mux(taken, current +| U(1), current -| U(1))

        // Shift left by one (width preserved) and insert the outcome as bit 0.
        slot.ph := ((history |<< 1) | taken.asUInt.resize(history_len))
    }
}

/** One slot of the TLAP table: a local history plus one 2-bit counter per pattern. */
class TLAPEntry extends Bundle {
    //val valid:Bool = Bool()
    //val tag:UInt = UInt(GShare.Tagwidth bits)
    val ph: UInt = UInt(TLAP.history_len bits)  //pattern history, or local history
    val sc = Vec(UInt(2 bits), TLAP.sc_count)

    /** Drives the history and every counter with 0; used to build the reset value. */
    def setEmpty: Unit = {
        //this.valid := False
        ph := U(0)
        sc.foreach(_ := U(0))
    }
}

/**
  * Fully associative BTB combined with a return-address stack (RAS) and a
  * two-level adaptive (TLAP) direction predictor.
  *
  * Lookup: the pc tag is compared against every valid entry. On a hit the
  * prediction is "taken", except for entries flagged `cond`, which use the TLAP
  * result. Entries flagged `ras` take their target from the top of the RAS; all
  * others use the stored target.
  *
  * Update (only when `io.update.valid`): the TLAP entry is trained, and the BTB
  * entry is either refreshed in place (hit) or allocated at `record` (miss).
  */
class C2BTB extends Component{
    val io = new Bundle{
        val lookup = master(new BTBLookup)
        val update = slave(new BTBUpdate)
        val ras = slave(new Ras)
    }
    val record = Reg(UInt(log2Up(BTBConfig.Count) bits)) init(0)
    //for insert and replace BTBentry

    val table = Reg(Vec(new ComplexBTBentry(BTBConfig.Tagwidth), BTBConfig.Count))
    val entryvalid = Reg(Bits(BTBConfig.Count bits)) init(0)
    // The tag is pc bits [Tagwidth + 1 : 2].
    val lookuptag = io.lookup.pc.asBits(BTBConfig.Tagwidth + 1 downto 2).asUInt
    val updatetag = io.update.pc.asBits(BTBConfig.Tagwidth + 1 downto 2).asUInt
    // One match bit per entry: the entry is valid and its tag equals the lookup tag.
    val onehotvec = Vec(Bool,BTBConfig.Count)
    for(i <- 0 until BTBConfig.Count){
        when(table(i).tag === lookuptag && entryvalid(i)){
            onehotvec(i) := True
        }.otherwise{
            onehotvec(i) := False
        }
    }

    // Return-address stack. `top` points at the next free slot and wraps around.
    val ras = new Area{
        val stack = Vec(Reg(LISA.GPR), BTBConfig.RasCount)
        val top = Reg(UInt(log2Up(BTBConfig.RasCount) bits)) init(0)
        // Most recently pushed address: the slot just below `top`, wrapping at 0.
        val value :UInt = Mux(top === U(0), stack(BTBConfig.RasCount - 1), stack(top - U(1)))
        val push = Bool()
        val pop = Bool()
        // NOTE(review): push/pop are driven from the flags of table(index) and are
        // qualified by io.lookup.valid only, not by `hit` — confirm this is intended.
        when(push && io.lookup.valid){
            stack(top) := io.ras.target
            top := top + U(1)
        }

        when(pop && io.lookup.valid){
            top := top - U(1)
        }
    }

    // TLAP direction predictor; histories and counters reset to 0.
    val predictor = new Area{
        val Empty = new TLAPEntry
        Empty.setEmpty
        val table = Vec(Reg(new TLAPEntry) init(Empty), TLAP.Count)
        val lookupindex = THR.xor(io.lookup.pc, TLAP.Width) //for debug
        val updateindex = THR.xor(io.update.pc, TLAP.Width) //for debug
        val taken :Bool = TLAP.lookup(table, io.lookup.pc, io.lookup.ghr)
        when(io.update.valid){
            TLAP.update(table, io.update.pc, io.update.ghr, io.update.taken)
        }
    }

/* lookup logic */
    val hit :Bool = onehotvec.reduce(_|_)
    val index :UInt = OHToUInt(onehotvec)
    val is_bl  :Bool = table(index).bl
    val is_cond :Bool = table(index).cond
    val is_ras  :Bool = table(index).ras
    val target  :UInt = table(index).target
    // Only `cond` entries consult the predictor; every other hit is predicted taken.
    val hitTaken :Bool = Mux(is_cond, predictor.taken, True)
    io.lookup.data.hit := hit
    io.lookup.data.index := index
    io.lookup.taken := Mux(hit, hitTaken, False)
    io.lookup.target :=  Mux(is_ras, ras.value, target)
    ras.pop := is_ras
    ras.push := is_bl


/* update logic */
    when(io.update.valid){
        when(io.update.data.hit){
            // Refresh the entry that produced the hit. All three kind flags are
            // written to that same entry; the replacement slot at `record` is
            // left untouched.
            table(io.update.data.index).ras := io.update.brutype === B(3)
            table(io.update.data.index).cond := io.update.brutype === B(2)
            table(io.update.data.index).bl := io.update.brutype === B(0)
            table(io.update.data.index).target := io.update.target //valid only when brutype === B(1)
        }.otherwise{
            // Allocate the slot at `record` and advance the replacement pointer.
            table(record).tag := updatetag
            table(record).ras := io.update.brutype === B(3)
            table(record).cond := io.update.brutype === B(2)
            table(record).bl := io.update.brutype === B(0)
            table(record).target := io.update.target //valid only when brutype === B(1)
            entryvalid(record) := True
            record := record + U(1)

        }
    }

}